import cv2
import os
import sys
# Put /system/lib at the front of sys.path so that modules found there take
# precedence over other installed copies when imported below.
# NOTE(review): presumably this is where the target device ships its Python
# packages (e.g. numpy / pyserial) — confirm against the deployment image.
package_path = "/system/lib"
if package_path not in sys.path:
    sys.path.insert(0,package_path)

import numpy as np
import serial

def postprocess(outputs):
    """Return the flat argmax of the first value stored in ``outputs``.

    Only the first entry in the mapping's iteration order is examined; the
    remaining entries are ignored. An empty mapping raises ``IndexError``.
    """
    first_scores = [*outputs.values()][0]
    return first_scores.argmax()

def calculate_bias(original_image):
    """Estimate the horizontal offset of the track midline from the image centre.

    The input is binarised with a fixed threshold of 115 and cleaned with a
    5x5 morphological close followed by an open. The convex hull of all white
    pixels is then computed. Hull vertices left of the median hull x are
    fitted with one line ``x = a*y + b``; vertices at or right of the median
    are fitted with a second line. The midpoint of the two fits at half image
    height is compared with the image centre.

    Args:
        original_image: 8-bit image. A single-channel image is used directly;
            a 3-channel image is assumed to be BGR and converted to grayscale.

    Returns:
        ``(bias, annotated)``.

        ``bias`` is in [-100, 100], with +/-100 meaning the image edge. It is
        positive when the midline lies left of the image centre and negative
        when it lies right. It is ``None`` when no white pixels exist or
        either side lacks two distinct y values to fit a line.

        ``annotated`` is always a 3-channel BGR rendering of the cleaned mask.
        When ``bias`` is not ``None``, the fitted lines, midline, centre
        marker and offset segment are drawn on it.
    """
    if original_image.ndim == 3 and original_image.shape[2] == 3:
        gray = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
    else:
        gray = original_image

    # Fixed-threshold binarisation, then close/open to remove speckle and holes.
    _, binary = cv2.threshold(gray, 115, 255, cv2.THRESH_BINARY)
    kernel = np.ones((5, 5), np.uint8)
    cleaned_image = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    cleaned_image = cv2.morphologyEx(cleaned_image, cv2.MORPH_OPEN, kernel)

    height, width = cleaned_image.shape[:2]
    mid_x = width // 2
    mid_y = height // 2

    # Always hand back BGR so the caller can paste it into a colour canvas
    # and so the coloured annotations below keep their colours.
    annotated = cv2.cvtColor(cleaned_image, cv2.COLOR_GRAY2BGR)

    y_coords, x_coords = np.nonzero(cleaned_image)
    if x_coords.size == 0:
        return None, annotated

    # cv2.convexHull expects int32 (or float32) points; np.nonzero yields int64.
    points = np.column_stack((x_coords, y_coords)).astype(np.int32)
    hull_points = cv2.convexHull(points).reshape(-1, 2)
    hull_x = hull_points[:, 0]
    hull_y = hull_points[:, 1]

    # Split hull vertices into left/right halves around the median x.
    median_x = np.median(hull_x)
    left_mask = hull_x < median_x
    right_mask = ~left_mask
    left_x, left_y = hull_x[left_mask], hull_y[left_mask]
    right_x, right_y = hull_x[right_mask], hull_y[right_mask]

    # Fitting x as a function of y needs at least two distinct y values per
    # side; with fewer the least-squares problem is rank-deficient.
    if np.unique(left_y).size < 2 or np.unique(right_y).size < 2:
        return None, annotated

    left_fit = np.polyfit(left_y, left_x, 1)
    right_fit = np.polyfit(right_y, right_x, 1)

    def x_at(fit, y):
        # Evaluate the fitted line x = fit[0] * y + fit[1], truncated to int.
        return int(fit[0] * y + fit[1])

    y1, y2 = 0, height
    left_x1, left_x2 = x_at(left_fit, y1), x_at(left_fit, y2)
    right_x1, right_x2 = x_at(right_fit, y1), x_at(right_fit, y2)

    cv2.line(annotated, (left_x1, y1), (left_x2, y2), (255, 0, 0), 2)  # Blue
    cv2.line(annotated, (right_x1, y1), (right_x2, y2), (0, 255, 0), 2)  # Green

    mid_x1 = (left_x1 + right_x1) // 2
    mid_x2 = (left_x2 + right_x2) // 2
    cv2.line(annotated, (mid_x1, y1), (mid_x2, y2), (0, 0, 255), 2)  # Red

    # Image centre drawn as a tilted cross; (255, 255, 0) is cyan in BGR.
    cv2.drawMarker(annotated, (mid_x, mid_y), (255, 255, 0), cv2.MARKER_TILTED_CROSS, 20, 2)

    # Midline x at half height, and the segment from the centre to it (cyan).
    mid_x_at_height = (x_at(left_fit, mid_y) + x_at(right_fit, mid_y)) // 2
    cv2.line(annotated, (mid_x, mid_y), (mid_x_at_height, mid_y), (255, 255, 0), 2)

    # Normalise the offset by half the width so the image edge maps to +/-100.
    bias = -100 * (mid_x_at_height - mid_x) / (width / 2)
    bias = np.clip(bias, -100, 100)

    return bias, annotated

def analyze_horizontal_line(cleaned_image, original_image):
    """Classify a frame as start line, end line, or plain track.

    Rows in which at least 60% of the pixels are zero are boundary candidates.
    When at least two exist, the zero pixels strictly above the first such row
    are compared with the zero pixels strictly below the last such row.

    Side effects: when boundary rows are found, the annotated mask is shown in
    the 'Fitting Results' window and copied into ``original_image`` in place.
    ``original_image`` must therefore be a 3-channel array with the same
    height and width as ``cleaned_image``.

    Args:
        cleaned_image: single-channel binary mask.
        original_image: 3-channel buffer that receives the visualisation.

    Returns:
        0 for a start line (more zeros above).
        1 for an end line (zeros below >= zeros above).
        2 when fewer than two boundary rows were found.
    """
    width = cleaned_image.shape[1]

    # Rows where at least 60% of the pixels are zero.
    zero_threshold = 0.6 * width
    zeros_per_row = np.count_nonzero(cleaned_image == 0, axis=1)
    zero_rows = np.flatnonzero(zeros_per_row >= zero_threshold)

    if zero_rows.size < 2:
        print("Detected: Straight Line (no clear boundary rows)")
        return 2  # No clear boundary rows found

    first_zero_row = int(zero_rows[0])
    last_zero_row = int(zero_rows[-1])

    # Both boundary rows are excluded so the two counts are symmetric.
    zeros_above = int(np.count_nonzero(cleaned_image[:first_zero_row] == 0))
    zeros_below = int(np.count_nonzero(cleaned_image[last_zero_row + 1:] == 0))

    # Visualise the boundaries and the counts on a colour copy of the mask.
    cleaned_image_color = cv2.cvtColor(cleaned_image, cv2.COLOR_GRAY2BGR)
    cv2.line(cleaned_image_color, (0, first_zero_row), (width, first_zero_row), (255, 0, 0), 2)
    cv2.line(cleaned_image_color, (0, last_zero_row), (width, last_zero_row), (0, 255, 0), 2)
    cv2.putText(cleaned_image_color, f'Zeros above: {zeros_above}', (10, first_zero_row - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
    cv2.putText(cleaned_image_color, f'Zeros below: {zeros_below}', (10, last_zero_row + 25),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    cv2.imshow('Fitting Results', cleaned_image_color)
    cv2.waitKey(1)
    original_image[:] = cleaned_image_color

    if zeros_above > zeros_below:
        print("Detected: Start Line (more zeros above)")
        return 0  # Start line
    print("Detected: End Line (more zeros below)")
    return 1  # End line

def _open_serial(port, baudrate):
    """Open ``port`` at ``baudrate``; print a notice and return None on failure."""
    try:
        return serial.Serial(port, baudrate, timeout=1)
    except serial.SerialException:
        # Message reads "Unable to open serial port <port>".
        print(f"无法打开串口 {port}")
        return None


def _select_command(pred_idx, bias):
    """Map the line classification and steering bias to ``(TS_class, label)``."""
    if pred_idx == 0:
        return 0, "Start"
    if pred_idx == 1:
        return 1, "End"
    if pred_idx == 2 and bias < 0:
        return 2, "Left"
    if pred_idx == 2 and bias >= 0:
        return 3, "Right"
    # Any other combination (e.g. a NaN bias) falls back to "End".
    return 1, "End"


def _encode_command(ts_class, ts_pos):
    """Build the 4-byte serial frame: b'C', class byte, b'P', position byte."""
    return bytes((ord('C'), ts_class, ord('P'), ts_pos))


def infer_from_camera(camera_index=0,
                      port='/dev/serial/by-id/usb-1a86_USB_Serial-if00-port0',
                      baudrate=115200):
    """Run the live camera loop, display a 2x2 analysis grid, and send commands.

    Each frame is processed as follows:
      * grayscale conversion, thresholding at 127, and a 5x5 close/open;
      * classification with ``analyze_horizontal_line``;
      * bias estimation with ``calculate_bias``;
      * when a bias is found, a command frame is written to the serial port.

    The serial port is opened lazily the first time a command is needed and
    reused afterwards. If opening fails, a notice is printed and another open
    is attempted on the next frame that needs it. Press 'q' to quit.

    Args:
        camera_index: index passed to ``cv2.VideoCapture``.
        port: serial device path for the motor controller.
        baudrate: serial baud rate.
    """
    cap = cv2.VideoCapture(camera_index)

    if not cap.isOpened():
        print("Error: Could not open camera.")
        return

    # Loop-invariant resources.
    kernel = np.ones((5, 5), np.uint8)
    font = cv2.FONT_HERSHEY_SIMPLEX
    ser = None

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Error: Could not read frame.")
                break

            gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            _, binary = cv2.threshold(gray_frame, 127, 255, cv2.THRESH_BINARY)
            cleaned_image = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
            cleaned_image = cv2.morphologyEx(cleaned_image, cv2.MORPH_OPEN, kernel)

            # The frame copy only serves as a scratch visualisation buffer.
            pred_idx = analyze_horizontal_line(cleaned_image, frame.copy())
            print(f"pred_idx: {pred_idx}")

            bias, annotated_frame = calculate_bias(binary)
            # Guard against a single-channel annotation, which cannot be
            # assigned into the 3-channel canvas.
            if annotated_frame.ndim == 2:
                annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_GRAY2BGR)

            # 2x2 grid: original | binary / cleaned | annotated.
            height, width = frame.shape[:2]
            canvas = np.full((height * 2, width * 2, 3), 255, dtype=np.uint8)
            canvas[:height, :width] = frame
            canvas[:height, width:] = cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
            canvas[height:, :width] = cv2.cvtColor(cleaned_image, cv2.COLOR_GRAY2BGR)
            canvas[height:, width:] = annotated_frame

            cv2.putText(canvas, 'Original', (10, 30), font, 1, (0, 0, 0), 2)
            cv2.putText(canvas, 'Preprocessed', (width + 10, 30), font, 1, (0, 0, 0), 2)
            cv2.putText(canvas, 'Cleaned', (10, height + 30), font, 1, (0, 0, 0), 2)
            cv2.putText(canvas, 'Annotated', (width + 10, height + 30), font, 1, (0, 0, 0), 2)

            if bias is not None:
                if ser is None:
                    ser = _open_serial(port, baudrate)

                ts_class, class_text = _select_command(pred_idx, bias)
                ts_pos = int(abs(bias))

                # Only send values that fit the 7-bit protocol bytes.
                if ser is not None and ts_pos <= 127 and ts_class <= 127:
                    try:
                        ser.write(_encode_command(ts_class, ts_pos))
                    except serial.SerialException as exc:
                        print(f"Serial write to {port} failed: {exc}")
                        ser.close()
                        ser = None  # retry opening on the next frame

                text_y = height * 2 - 30
                cv2.putText(canvas, f'Class: {pred_idx} ({class_text})', (10, text_y), font, 0.7, (0, 0, 0), 2)
                cv2.putText(canvas, f'Bias: {bias:.2f}', (width + 10, text_y), font, 0.7, (0, 0, 0), 2)

            cv2.imshow('Analysis Results', canvas)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    finally:
        if ser is not None:
            ser.close()
        cap.release()
        cv2.destroyAllWindows()

# Start the live camera loop only when executed as a script, not on import.
if __name__ == '__main__':
    infer_from_camera()